from d2l import torch as d2l
from torch import nn

from config import *
from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder
from attention import sequence_mask
from data import loaddata, MyDataset, Lang
import torch.nn.functional as F


def loss_fn(out, tar):
    """Token-level cross entropy over a batch of sequences.

    Flattens `out` (batch, seq, vocab) to (batch*seq, vocab) and `tar`
    (batch, seq) to (batch*seq,), then averages the per-token loss.
    Positions whose target index is 2 are ignored (marked "pad" here,
    though 2 is also used as <bos> elsewhere in this file — verify).
    """
    vocab_size = out.shape[-1]
    logits = out.reshape(-1, vocab_size)
    targets = tar.reshape(-1)
    return F.cross_entropy(logits, targets, ignore_index=2)


def _evaluate(net, valid_iter, device):
    """Teacher-forced evaluation pass over `valid_iter`.

    Returns (avg_loss, avg_ppl) where the perplexity is exp(loss) averaged
    per batch. Puts the net in eval mode; caller restores train mode.
    """
    net.eval()
    loss_sum = ppl_sum = 0.0
    with torch.no_grad():
        for batch in valid_iter:
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # Decoder input: <bos> (index 2 — presumably lang.word2index['<bos>'],
            # TODO confirm) prepended, last target token dropped.
            bos = torch.tensor([2] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss_fn(Y_hat, Y)
            loss_sum += l.item()
            ppl_sum += torch.exp(l).item()
    return loss_sum / len(valid_iter), ppl_sum / len(valid_iter)


def train_sampling(net, data_iter, valid_iter, lr, num_epochs, lang, device, first_train=True, min_ppl=1e9, tg=0.5, tr=0.95, i=70):
    """Train a seq2seq model with confidence-based corrupted teacher forcing.

    For each batch, a no-grad forward pass under teacher forcing yields
    per-token confidences; the decoder input for the gradient step mixes:
      * the model's own argmax where confidence is in (tg, tr),
      * the gold token where confidence <= tg,
      * a uniformly random token where confidence >= tr.
    The loss is always computed against the gold targets.

    Args:
        net: EncoderDecoder model (moved to `device` here).
        data_iter / valid_iter: training / validation batch iterators yielding
            (X, X_valid_len, Y, Y_valid_len).
        lr: Adam learning rate.
        num_epochs: number of training epochs.
        lang: vocabulary; len(lang) bounds the random-token sampling.
        device: torch device.
        first_train: if True, Xavier-init weights; else run an initial
            validation pass and seed `min_ppl` from it.
        min_ppl: best validation perplexity so far; checkpoints are written
            whenever it improves (sentinel 1e9 suppresses the first save).
        tg, tr: lower/upper confidence thresholds for the mixing rule.
        i: unused; kept for backward compatibility with existing callers.
    """
    def xavier_init_weights(m):
        # Xavier-uniform init for Linear layers and GRU weight matrices.
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    lang_len = len(lang)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    net.train()
    animator = Animator(xlabel='epoch', ylabel='loss',
                        xlim=[0, num_epochs])
    animator1 = Animator(xlabel='epoch', ylabel='ppl',
                         xlim=[0, num_epochs])
    # Initialize all metrics so the final prints are defined even when
    # num_epochs == 0 (the original could hit a NameError in that case).
    LOSS = PPL = LOSS_AVG = PPL_AVG = 0
    timer = Timer()

    if first_train:
        net.apply(xavier_init_weights)
    else:
        # Resuming: measure the current validation quality and use it as the
        # checkpointing baseline, plotted at epoch 0.
        LOSS_AVG, PPL_AVG = _evaluate(net, valid_iter, device)
        min_ppl = PPL_AVG
        animator1.add(0, (PPL_AVG, PPL_AVG,))
        animator.add(0, (LOSS_AVG, LOSS_AVG,))

    for epoch in range(num_epochs):
        # FIX: reset the per-epoch accumulators. Previously they carried the
        # prior epoch's averaged value into the next epoch's sums, corrupting
        # every reported training metric after epoch 1.
        LOSS = PPL = 0
        for batch in data_iter:
            # Pass 1 (no grad, eval mode): teacher-forced forward to obtain
            # per-token max-probabilities M and argmax indices I.
            net.eval()
            with torch.no_grad():
                X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
                bos = torch.tensor([2] * Y.shape[0],
                                   device=device).reshape(-1, 1)
                dec_input = torch.cat([bos, Y[:, :-1]], 1)
                Y_hat, _ = net(X, dec_input, X_valid_len)
                p = Y_hat.softmax(2)
                M, I = torch.max(p, dim=2)
                # Mixed decoder targets per the tg/tr confidence rule.
                Y0 = (I * ((M > tg) * (M < tr))
                      + Y * (M <= tg)
                      + (M >= tr) * torch.randint(0, lang_len, Y.shape, device=device))

            # Pass 2 (train mode): gradient step with the corrupted decoder
            # input, loss still against the gold targets Y.
            net.train()
            optimizer.zero_grad()
            bos = torch.tensor([2] * Y0.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y0[:, :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss_fn(Y_hat, Y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                LOSS += l.item()
                PPL += torch.exp(l).item()
        LOSS /= len(data_iter)
        PPL /= len(data_iter)

        LOSS_AVG, PPL_AVG = _evaluate(net, valid_iter, device)
        if PPL_AVG < min_ppl:
            if min_ppl != 1e9:  # sentinel: skip saving on the very first improvement
                # FIX: save the components of the net being trained rather than
                # the module-level globals `encoder`/`decoder`, which only exist
                # when this function is called from this file's __main__.
                # Assumes EncoderDecoder exposes .encoder/.decoder — TODO
                # confirm against transformer.py.
                torch.save(net.encoder, MODEL_ROOT + "trans_encoder.mdl")
                torch.save(net.decoder, MODEL_ROOT + "trans_decoder.mdl")
            min_ppl = PPL_AVG

        animator1.add(epoch + 1, (PPL, PPL_AVG,))
        animator.add(epoch + 1, (LOSS, LOSS_AVG,))

    print(f'loss {LOSS:.3f}, PPL {PPL:.3f}, {timer.stop() / 60:.1f} '
          f'min on {str(device)}')
    print(f'loss {LOSS_AVG:.3f}, PPL {PPL_AVG:.3f} ')
    print(lr)




if __name__ == "__main__":
    dataset = torch.load(DATA_ROOT + "dataset{}".format(num_examples))
    lang = torch.load(DATA_ROOT + "lang{}".format(num_examples))
    valid_data = torch.load(DATA_ROOT + "validdata")
    train_iter = loaddata(dataset, 64)
    valid_iter = loaddata(valid_data, 64)
    print(len(lang))

    # 初始化，词表大小，kqv，hidden，
    # encoder = TransformerEncoder(
    #     len(lang), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    # decoder = TransformerDecoder(
    #     len(lang), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    encoder = torch.load(MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    decoder = torch.load(MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))

    net = EncoderDecoder(encoder, decoder)
    train_sampling(net, train_iter, valid_iter, lr, num_epochs, lang, device, first_train=False)

    torch.save(encoder, MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    torch.save(decoder, MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))
